{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Deep Learning - Deep Vision Classifier"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Environment Setup on Databricks\n",
    "### -- Reinstall Horovod based on the new version of PyTorch"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# install cloudpickle 2.0.0 to add synapse module for usage of horovod\n",
    "%pip install cloudpickle==2.0.0 --force-reinstall --no-deps"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import synapse\n",
    "import cloudpickle\n",
    "import os\n",
    "import urllib.request\n",
    "import zipfile\n",
    "\n",
    "cloudpickle.register_pickle_by_value(synapse)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "! horovodrun --check-build"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from pyspark.sql.functions import udf, col, regexp_replace\n",
    "from pyspark.sql.types import IntegerType\n",
    "from pyspark.ml.evaluation import MulticlassClassificationEvaluator"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Read Dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "folder_path = \"/tmp/flowers_prepped\"\n",
    "zip_url = \"https://mmlspark.blob.core.windows.net/datasets/Flowers/flowers_prepped.zip\"\n",
    "zip_path = \"/dbfs/tmp/flowers_prepped.zip\"\n",
    "\n",
    "if not os.path.exists(\"/dbfs\" + folder_path):\n",
    "    urllib.request.urlretrieve(zip_url, zip_path)\n",
    "    with zipfile.ZipFile(zip_path, \"r\") as zip_ref:\n",
    "        zip_ref.extractall(\"/dbfs/tmp\")\n",
    "    os.remove(zip_path)"
   ],
   "metadata": {
    "collapsed": false
   }
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def assign_label(path):\n",
    "    \"\"\"Turn an image file path into an integer class label.\n",
    "\n",
    "    The last path segment is cut at its first dot, the remainder is split on\n",
    "    underscores, and the second piece is parsed as an integer. The label is\n",
    "    that integer floor-divided by 81.\n",
    "    \"\"\"\n",
    "    file_name = path.split(\"/\")[-1]\n",
    "    stem = file_name.split(\".\")[0]\n",
    "    image_number = int(stem.split(\"_\")[1])\n",
    "    # NOTE(review): 81 is presumably tied to how many images each class has - confirm against the dataset.\n",
    "    return image_number // 81\n",
    "\n",
    "\n",
    "assign_label_udf = udf(assign_label, IntegerType())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Read the training JPEGs as binary files and keep two columns:\n",
    "#   image - the file path with the \"dbfs:\" prefix replaced by \"/dbfs\"\n",
    "#   label - the class index computed from the path by assign_label_udf\n",
    "train_files = (\n",
    "    spark.read.format(\"binaryFile\")\n",
    "    .option(\"pathGlobFilter\", \"*.jpg\")\n",
    "    .load(folder_path + \"/train\")\n",
    ")\n",
    "train_df = train_files.select(\n",
    "    regexp_replace(\"path\", \"dbfs:\", \"/dbfs\").alias(\"image\"),\n",
    "    assign_label_udf(col(\"path\")).alias(\"label\"),\n",
    ")\n",
    "\n",
    "display(train_df.limit(100))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Read the test JPEGs as binary files and keep two columns:\n",
    "#   image - the file path with the \"dbfs:\" prefix replaced by \"/dbfs\"\n",
    "#   label - the class index computed from the path by assign_label_udf\n",
    "test_files = (\n",
    "    spark.read.format(\"binaryFile\")\n",
    "    .option(\"pathGlobFilter\", \"*.jpg\")\n",
    "    .load(folder_path + \"/test\")\n",
    ")\n",
    "test_df = test_files.select(\n",
    "    regexp_replace(\"path\", \"dbfs:\", \"/dbfs\").alias(\"image\"),\n",
    "    assign_label_udf(col(\"path\")).alias(\"label\"),\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Training"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import uuid\n",
    "\n",
    "from horovod.spark.common.store import DBFSLocalStore\n",
    "from pytorch_lightning.callbacks import ModelCheckpoint\n",
    "\n",
    "# Import the one name this notebook uses instead of a wildcard import.\n",
    "from synapse.ml.dl import DeepVisionClassifier\n",
    "\n",
    "# Each run gets its own output directory, named with the first 8 characters of a random UUID.\n",
    "run_output_dir = f\"/dbfs/FileStore/test/resnet50/{str(uuid.uuid4())[:8]}\"\n",
    "store = DBFSLocalStore(run_output_dir)\n",
    "\n",
    "epochs = 10\n",
    "\n",
    "# Checkpoint file names carry the epoch and the training loss rounded to two decimals.\n",
    "callbacks = [ModelCheckpoint(filename=\"{epoch}-{train_loss:.2f}\")]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Configure a DeepVisionClassifier with a resnet50 backbone for 17 classes, then fit it on train_df.\n",
    "deep_vision_classifier = DeepVisionClassifier(\n",
    "    # model\n",
    "    backbone=\"resnet50\",\n",
    "    num_classes=17,\n",
    "    # training loop\n",
    "    epochs=epochs,\n",
    "    batch_size=16,\n",
    "    validation=0.1,\n",
    "    # run output and checkpointing\n",
    "    store=store,\n",
    "    callbacks=callbacks,\n",
    ")\n",
    "\n",
    "deep_vision_model = deep_vision_classifier.fit(train_df)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Prediction"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Score the test images and report the accuracy of the prediction column against the label column.\n",
    "pred_df = deep_vision_model.transform(test_df)\n",
    "evaluator = MulticlassClassificationEvaluator(\n",
    "    metricName=\"accuracy\", labelCol=\"label\", predictionCol=\"prediction\"\n",
    ")\n",
    "test_accuracy = evaluator.evaluate(pred_df)\n",
    "print(\"Test accuracy:\", test_accuracy)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Cleanup the output dir for test.\n",
    "# run_output_dir is a local \"/dbfs/...\" mount path, while dbutils.fs resolves a bare path\n",
    "# against the DBFS root, so translate it to the matching \"dbfs:/...\" URI before removing it.\n",
    "dbutils.fs.rm(run_output_dir.replace(\"/dbfs\", \"dbfs:\", 1), True)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3.8.5 ('base')",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "name": "python",
   "version": "3.8.5"
  },
  "orig_nbformat": 4,
  "vscode": {
   "interpreter": {
    "hash": "601a75c4c141f401603984f1538447337114e368c54c4d5b589ea94315afdca2"
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
